{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Introduction to the Molecular Attention Transformer\n",
    "\n",
    "In this tutorial we will learn about the Molecular Attention Transformer (MAT), a transformer-based model for molecular property prediction. MAT is easy to tune and performs competitively with other molecular property prediction models. Its attention weights are also chemically interpretable, which makes the model useful for understanding its predictions.\n",
    "\n",
    "Reference Paper: [Molecular Attention Transformer, Maziarka et al.](https://arxiv.org/abs/2002.08264)\n",
    "\n",
    "We will explore how to train MAT and use it to predict hydration free energies for molecules from the FreeSolv dataset.\n",
    "\n",
    "## Colab\n",
    "\n",
    "This tutorial and the rest in this sequence are designed to be done in Google Colab. If you'd like to open this notebook in Colab, you can use the following link.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/deepchem/blob/master/examples/tutorials/Introduction_to_Molecular_Attention_Transformer.ipynb)\n",
    "\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!pip install --pre deepchem"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Import required modules"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import deepchem as dc\n",
    "from deepchem.models.torch_models import MATModel\n",
    "from deepchem.feat import MATFeaturizer\n",
    "import matplotlib.pyplot as plt"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Molecule Featurization using MATFeaturizer\n",
    "\n",
    "MATFeaturizer is the featurizer intended for use with the Molecular Attention Transformer (MAT). It takes a SMILES string or a molecule as input and returns a MATEncoding dataclass object containing three NumPy arrays: the node features matrix, the adjacency matrix, and the distance matrix."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "featurizer = dc.feat.MATFeaturizer()\n",
    "# Featurize an example list of SMILES strings (here, just propane).\n",
    "smile_string = [\"CCC\"]\n",
    "output = featurizer.featurize(smile_string)\n",
    "print(type(output[0]))\n",
    "print(output[0].node_features)\n",
    "print(output[0].adjacency_matrix)\n",
    "print(output[0].distance_matrix)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "<class 'deepchem.feat.molecule_featurizers.mat_featurizer.MATEncoding'>\n",
      "[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
      "  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.\n",
      "  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.\n",
      "  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n",
      "  0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]]\n",
      "[[0. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 1. 1. 0.]]\n",
      "[[1.e+06 1.e+06 1.e+06 1.e+06]\n",
      " [1.e+06 0.e+00 2.e+00 1.e+00]\n",
      " [1.e+06 2.e+00 0.e+00 1.e+00]\n",
      " [1.e+06 1.e+00 1.e+00 0.e+00]]\n"
     ]
    }
   ],
   "metadata": {}
  },
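  {
   "cell_type": "markdown",
   "source": [
    "The arrays printed above have four rows even though propane has only three heavy atoms. Row 0 appears to be the extra dummy node described in the MAT paper: it has no bonds, and its distance to everything is set to a large value (1e6). The cell below is a minimal sketch that rebuilds the same adjacency and distance matrices by hand. It assumes the distances are bond hop counts, which matches the printed values but may not be how the featurizer actually computes them. `hop_distance_matrix` is a hypothetical helper written for this illustration and is not part of DeepChem."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def hop_distance_matrix(adjacency, unreachable=1e6):\n",
    "    # Hypothetical helper (not a DeepChem function): all-pairs bond hop counts\n",
    "    # computed with a breadth-first search from every node.\n",
    "    n = len(adjacency)\n",
    "    dist = np.full((n, n), unreachable)\n",
    "    for start in range(n):\n",
    "        seen = {start: 0}\n",
    "        frontier = [start]\n",
    "        while frontier:\n",
    "            next_frontier = []\n",
    "            for u in frontier:\n",
    "                for v in range(n):\n",
    "                    if adjacency[u][v] and v not in seen:\n",
    "                        seen[v] = seen[u] + 1\n",
    "                        next_frontier.append(v)\n",
    "            frontier = next_frontier\n",
    "        for v, d in seen.items():\n",
    "            dist[start, v] = d\n",
    "    return dist\n",
    "\n",
    "\n",
    "# Index 0 is the dummy node. Following the printed output, indices 1 and 2\n",
    "# are the terminal carbons and index 3 is the central carbon.\n",
    "adjacency = np.zeros((4, 4))\n",
    "for i, j in [(1, 3), (2, 3)]:\n",
    "    adjacency[i, j] = adjacency[j, i] = 1.0\n",
    "\n",
    "distance = hop_distance_matrix(adjacency)\n",
    "# The dummy node is treated as far away from every node, including itself.\n",
    "distance[0, :] = 1e6\n",
    "distance[:, 0] = 1e6\n",
    "\n",
    "print(adjacency)\n",
    "print(distance)"
   ],
   "outputs": [],
   "metadata": {}
  },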
  {
   "cell_type": "markdown",
   "source": [
    "## Getting the FreeSolv hydration free energy dataset\n",
    "\n",
    "We will now load the FreeSolv hydration free energy dataset from MoleculeNet. If the dataset already exists in the local data directory, that copy is used. Otherwise, DeepChem downloads it automatically from its AWS bucket."
   ],
   "metadata": {}
  },
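  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "# Request MAT features explicitly so the dataset matches what MATModel expects.\n",
    "tasks, dataset, transformers = dc.molnet.load_freesolv(featurizer=MATFeaturizer())\n",
    "train_dataset, val_dataset, test_dataset = dataset\n",
    "# The dataset ids are the SMILES strings of the molecules.\n",
    "train_smiles = train_dataset.ids\n",
    "val_smiles = val_dataset.ids"
   ],
   "outputs": [],
   "metadata": {}
  },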
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "train_dataset"
   ],
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<DiskDataset X.shape: (513,), y.shape: (513, 1), w.shape: (513, 1), ids: ['CCCCNCCCC' 'CCOC=O' 'CCCCCCCCC' ... 'COC' 'CCCCCCCCBr'\n",
       " 'CCCc1ccc(c(c1)OC)O'], task_names: ['y']>"
      ]
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training the model\n",
    "\n",
    "Now that we have the dataset and the necessary imports, we can instantiate DeepChem's implementation of the Molecular Attention Transformer, MATModel, and train it. We use the default parameters in this tutorial, but they can be changed when constructing the model."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "device = 'cpu'\n",
    "model = MATModel(device = device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "losses, val_losses = [], []"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "%%time\n",
    "max_epochs = 10\n",
    "# Create the validation metric once instead of rebuilding it every epoch.\n",
    "metric = dc.metrics.Metric(dc.metrics.score_function.rms_score)\n",
    "\n",
    "for epoch in range(max_epochs):\n",
    "    # Train for one epoch; the training losses are appended to `losses`.\n",
    "    model.fit(train_dataset, nb_epoch=1, max_checkpoints_to_keep=1, all_losses=losses)\n",
    "    # Square the RMS score to get a mean squared error on the validation set.\n",
    "    val_losses.append(model.evaluate(val_dataset, metrics=[metric])['rms_score']**2)\n",
    "\n",
    "# The UserWarnings printed by this cell are not relevant to this tutorial and can be safely ignored."
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/mat.py:165: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  node_features = torch.tensor(data[0]).float()\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/mat.py:166: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  adjacency_matrix = torch.tensor(data[1]).float()\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/mat.py:167: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  distance_matrix = torch.tensor(data[2]).float()\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/layers.py:171: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.sum(torch.tensor(adj_matrix), dim=-1).unsqueeze(2) + eps)\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/layers.py:178: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  distance_matrix = torch.tensor(distance_matrix).squeeze().masked_fill(\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 20min 48s, sys: 10.7 s, total: 20min 58s\n",
      "Wall time: 2min 41s\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "f, ax = plt.subplots()\n",
    "ax.scatter(range(len(losses)), losses, label='train loss')\n",
    "ax.scatter(range(len(val_losses)), val_losses, label='val loss')\n",
    "plt.legend(loc='upper right');"
   ],
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAYv0lEQVR4nO3df3BU9b3/8ef7G4IJKhIlICZU4nwZlYSQyMqk5TtYiy3+BqfiYNWhjsqdb/latXdSg62Uub2OXLFXy3fqncvQ+sUfLWa4qFy1ei1q0TtWDb9EQL4iqNmAkvA1FK9BILy/f+yBLhI0myV7wn5ejxln97z3nD3vPWNee/ics+eYuyMiImH4b3E3ICIiuaPQFxEJiEJfRCQgCn0RkYAo9EVEAtIv7ga+zuDBg33EiBFxtyEiclxZuXJlm7uXfrne50N/xIgRNDU1xd2GiMhxxcw+7Kqu4R0RkYAo9EVEAqLQFxEJSJ8f0xeR/LVv3z6SySR79uyJu5XjVlFREeXl5RQWFnZrfoW+iMQmmUxy8sknM2LECMws7naOO+7Ozp07SSaTVFRUdGuZvA/9p1a3MO+FTWxr7+CMQcXUTzqbKbVlcbclIsCePXsU+FkwM0477TRaW1u7vUxeh/5Tq1uYtXQdHfs6AWhp72DW0nUACn6RPkKBn51Mt19eH8id98KmQ4F/UMe+Tua9sCmmjkRE4pXXob+tvSOjuoiEpb29nYceeqhHy1566aW0t7d3e/45c+Zw//3392hdx1Jeh/4Zg4ozqotIWL4q9Ds7O7usH/Tcc88xaNCg3mirV31t6JvZ78xsh5m9k1Y71cxeNLP3oseStNdmmdlmM9tkZpPS6mPNbF302nzLwUBe/aSzKS4sOKxWXFhA/aSze3vVItILnlrdwvi5L1HR8Czj577EU6tbsnq/hoYG3n//fWpqaqivr+eVV17hwgsv5Ac/+AGjR48GYMqUKYwdO5bKykoWLFhwaNkRI0bQ1tbGBx98wLnnnsstt9xCZWUl3/ve9+jo+OrRhDVr1lBXV0d1dTVXXXUVn376KQDz589n1KhRVFdXM23aNAD+/Oc/U1NTQ01NDbW1tezevTurz4y7f+V/wATgPOCdtNp9QEP0vAH4p+j5KGAtcAJQAbwPFESvvQl8EzDgj8AlX7dud2fs2LGejSdXJf1b9y73EXc+49+6d7k/uSqZ1fuJyLGzYcOGbs/75Kqkn/PzP/qZdz5z6L9zfv7HrP6mt27d6pWVlYemX375ZR8wYIBv2bLlUG3nzp3u7v755597ZWWlt7W1ubv7mWee6a2trb5161YvKCjw1atXu7v71KlT/dFHHz1iXb/4xS983rx57u4+evRof+WVV9zd/e677/bbbrvN3d2HDRvme/bscXf3Tz/91N3dL7/8cn/ttdfc3X337t2+b9++I967q+0INHkXmfq1e/ruvgL4f18qTwYWRc8XAVPS6ovd/Qt33wpsBsaZ2TBgoLu/HjXzSNoyvWpKbRn/2fAdts69jP9s+I7O2hE5TuXqxIxx48Ydds77/PnzGTNmDHV1dTQ3N/Pee+8dsUxFRQU1NTUAjB07lg8++OCo779r1y7a29u54IILAJg+fTorVqwAoLq6muuuu47HHnuMfv1SJ1eOHz+en/zkJ8yfP5/29vZD9Z7q6Zj+UHffDhA9DonqZUBz2nzJqFYWPf9yXUSkW3J1YsaJJ5546Pkrr7zCn/70J15//XXWrl1LbW1tl78ePuGEEw49LygoYP/+/T1a97PPPsvMmTNZuXIlY8eOZf/+/TQ0NLBw4UI6Ojqoq6vj3Xff7dF7H3SsD+R2NU7vX1Hv+k3MZphZk5k1ZfKjAxHJX71xYsbJJ5/8lWPku3btoqSkhAEDBvDuu+/yl7/8pcfrOuiUU06hpKSEV199FYBHH32UCy64gAMHDtDc3MyFF17IfffdR3t7O599
9hnvv/8+o0eP5s477ySRSMQW+p9EQzZEjzuiehIYnjZfObAtqpd3Ue+Suy9w94S7J0pLj7gHgIgEqDdOzDjttNMYP348VVVV1NfXH/H6xRdfzP79+6murubuu++mrq6ux+tKt2jRIurr66murmbNmjXMnj2bzs5Orr/+ekaPHk1tbS133HEHgwYN4sEHH6SqqooxY8ZQXFzMJZdcktW6LTXE/jUzmY0AnnH3qmh6HrDT3eeaWQNwqrv/1Mwqgd8D44AzgOXASHfvNLO3gFuBN4DngP/t7s993boTiYTrJioi+Wnjxo2ce+653Z5fl1XpWlfb0cxWunviy/N+7REBM/sD8G1gsJklgV8Ac4FGM7sJ+AiYCuDu682sEdgA7AdmuvvBIy//E/g/QDGps3f+2JMPJyLhmlJbppDP0teGvrtfe5SXJh5l/nuAe7qoNwFVGXUnIiLHVF7/IldERA6n0BcRCYhCX0QkIAp9EZGAKPRFRDJw0kknZVTvaxT6IiIBUeiLyPHj7UZ4oArmDEo9vt2Y1dvdeeedh11Pf86cOfzqV7/is88+Y+LEiZx33nmMHj2ap59+utvv6e7U19dTVVXF6NGjeeKJJwDYvn07EyZMoKamhqqqKl599VU6Ozv54Q9/eGjeBx54IKvP0x15fY9cEckjbzfCv/8Y9kUXWNvVnJoGqL6mR285bdo0br/9dn70ox8B0NjYyPPPP09RURFPPvkkAwcOpK2tjbq6Oq688spu3Y926dKlrFmzhrVr19LW1sb555/PhAkT+P3vf8+kSZP42c9+RmdnJ59//jlr1qyhpaWFd95J3a4kkztx9ZT29EXk+LD8H/4W+Aft60jVe6i2tpYdO3awbds21q5dS0lJCd/4xjdwd+666y6qq6u56KKLaGlp4ZNPPunWe7722mtce+21FBQUMHToUC644ALeeustzj//fB5++GHmzJnDunXrOPnkkznrrLPYsmULt956K88//zwDBw7s8WfpLoW+iBwfdiUzq3fT1VdfzZIlS3jiiScO3a3q8ccfp7W1lZUrV7JmzRqGDh3a5SWVu3K065lNmDCBFStWUFZWxg033MAjjzxCSUkJa9eu5dvf/ja/+c1vuPnmm7P6LN2h0BeR48Mp5ZnVu2natGksXryYJUuWcPXVVwOpSyoPGTKEwsJCXn75ZT788MNuv9+ECRN44okn6OzspLW1lRUrVjBu3Dg+/PBDhgwZwi233MJNN93EqlWraGtr48CBA3z/+9/nl7/8JatWrcrqs3SHxvRF5PgwcfbhY/oAhcWpehYqKyvZvXs3ZWVlDBs2DIDrrruOK664gkQiQU1NDeecc0633++qq67i9ddfZ8yYMZgZ9913H6effjqLFi1i3rx5FBYWctJJJ/HII4/Q0tLCjTfeyIEDBwC49957s/os3dGtSyvHSZdWFslfmV5ambcbU2P4u5KpPfyJs3t8EDefHNNLK4uI9BnV1yjks6QxfRGRgCj0RSRWfX2Iua/LdPsp9EUkNkVFRezcuVPB30Puzs6dOykqKur2MhrTF5HYlJeXk0wmaW1tjbuV41ZRURHl5d0/bVWhLyKxKSwspKKiIu42gqLhHRGRgCj0RUQCotAXEQmIQl9EJCAKfRGRgCj0RUQCotAXEQmIQl9EJCAKfRGRgCj0RUQCklXom9kdZrbezN4xsz+YWZGZnWpmL5rZe9FjSdr8s8xss5ltMrNJ2bcvIiKZ6HHom1kZ8GMg4e5VQAEwDWgAlrv7SGB5NI2ZjYperwQuBh4ys4Ls2hcRkUxkO7zTDyg2s37AAGAbMBlYFL2+CJgSPZ8MLHb3L9x9K7AZGJfl+kVEJAM9Dn13bwHuBz4CtgO73P0/gKHuvj2aZzswJFqkDGhOe4tkVBMRkRzJZninhNTeewVwBnCimV3/VYt0UevyzglmNsPMmsysSdfZFhE5drIZ3rkI2Orure6+D1gKfAv4xMyGAUSPO6L5k8DwtOXLSQ0H
HcHdF7h7wt0TpaWlWbQoIiLpsgn9j4A6MxtgZgZMBDYCy4Dp0TzTgaej58uAaWZ2gplVACOBN7NYv4iIZKjHd85y9zfMbAmwCtgPrAYWACcBjWZ2E6kvhqnR/OvNrBHYEM0/0907s+xfREQyYH39hsSJRMKbmpribkNE5LhiZivdPfHlun6RKyISEIW+iEhAFPoiIgFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISkKxC38wGmdkSM3vXzDaa2TfN7FQze9HM3oseS9Lmn2Vmm81sk5lNyr59ERHJRLZ7+r8Gnnf3c4AxwEagAVju7iOB5dE0ZjYKmAZUAhcDD5lZQZbrFxGRDPQ49M1sIDAB+C2Au+9193ZgMrAomm0RMCV6PhlY7O5fuPtWYDMwrqfrFxGRzGWzp38W0Ao8bGarzWyhmZ0IDHX37QDR45Bo/jKgOW35ZFQ7gpnNMLMmM2tqbW3NokUREUmXTej3A84D/sXda4H/IhrKOQrrouZdzejuC9w94e6J0tLSLFoUEZF02YR+Eki6+xvR9BJSXwKfmNkwgOhxR9r8w9OWLwe2ZbF+ERHJUI9D390/BprN7OyoNBHYACwDpke16cDT0fNlwDQzO8HMKoCRwJs9Xb+IiGSuX5bL3wo8bmb9gS3AjaS+SBrN7CbgI2AqgLuvN7NGUl8M+4GZ7t6Z5fpFRCQDWYW+u68BEl28NPEo898D3JPNOkVEpOf0i1wRkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYBkHfpmVmBmq83smWj6VDN70czeix5L0uadZWabzWyTmU3Kdt0iIpKZY7GnfxuwMW26AVju7iOB5dE0ZjYKmAZUAhcDD5lZwTFYv4iIdFNWoW9m5cBlwMK08mRgUfR8ETAlrb7Y3b9w963AZmBcNusXEZHMZLun/yDwU+BAWm2ou28HiB6HRPUyoDltvmRUO4KZzTCzJjNram1tzbJFERE5qMehb2aXAzvcfWV3F+mi5l3N6O4L3D3h7onS0tKetigiIl/SL4tlxwNXmtmlQBEw0MweAz4xs2Huvt3MhgE7ovmTwPC05cuBbVmsX0REMtTjPX13n+Xu5e4+gtQB2pfc/XpgGTA9mm068HT0fBkwzcxOMLMKYCTwZo87FxGRjGWzp380c4FGM7sJ+AiYCuDu682sEdgA7AdmuntnL6xfRESOwty7HFbvMxKJhDc1NcXdhojIccXMVrp74st1/SJXRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgCn0RkYAo9EVEAqLQFxEJiEJfRCQgPQ59MxtuZi+b2UYzW29mt0X1U83sRTN7L3osSVtmlpltNrNNZjbpWHwAERHpvmz29PcDf+/u5wJ1wEwzGwU0AMvdfSSwPJomem0aUAlcDDxkZgXZNC8iIpnpcei7+3Z3XxU93w1sBMqAycCiaLZFwJTo+WRgsbt/4e5bgc3AuJ6uX0REMndMxvTNbARQC7wBDHX37ZD6YgCGRLOVAc1piyWjWlfvN8PMmsysqbW19Vi0KCIiHIPQN7OT
gH8Dbnf3v37VrF3UvKsZ3X2BuyfcPVFaWpptiyIiEskq9M2skFTgP+7uS6PyJ2Y2LHp9GLAjqieB4WmLlwPbslm/iIhkJpuzdwz4LbDR3f857aVlwPTo+XTg6bT6NDM7wcwqgJHAmz1dv4iIZK5fFsuOB24A1pnZmqh2FzAXaDSzm4CPgKkA7r7ezBqBDaTO/Jnp7p1ZrF9ERDLU49B399foepweYOJRlrkHuKen6xQRkezoF7kiIgFR6IuIBEShLyISEIW+iEhA8j/0326EB6pgzqDU49uNcXckIhKbbE7Z7PveboR//zHs60hN72pOTQNUXxNfXyIiMcnvPf3l//C3wD9oX0eqLiISoPwO/V3JzOoiInkuv0P/lPLM6iIieS6/Q3/ibCgsPrxWWJyqi4gEKL9Dv/oauGI+nDIcsNTjFfN1EFdEgpXfZ+9AKuAV8iIiQL7v6YuIyGEU+iIiAVHoi4gERKEvIhIQhb6ISEAU+iIiAVHoi4gERKEvIhIQhb6ISEAU+rmgG7mISB+R/5dhiJtu5CIifYj29HubbuQiIn2IQr+X+VFu2HK0uohIb1Lo97JPGJxRXUSkNyn0e9m9e6fyufc/rPa59+fevVNz2sdTq1sYP/clKhqeZfzcl3hqdUtO1y8ifYNCv5c1DfwuDftuJnlgMAfcSB4YTMO+m2ka+N2c9fDU6hZmLV1HS3sHDrS0dzBr6ToFv0iAdPZOL6ufdDazlu5l2d7/cahWXFjAvZPOzlkP817YxHc7/8xP+zdyhrWxzQdz3/5rmPdCf6bUluWsj7eW/SvDV81jiLeyw0ppPq+e86/8u5yt/5C3G1MH0nclU/dLnjhbZ1JJMHIe+mZ2MfBroABY6O5zc91DLh0M1XkvbGJbewdnDCqmftLZOQ3bxF9f5N7ChQywvQCUWxtzCxcy668A38lJD28t+1eqVv6cYtsLBqfTyikrf85bkNvgf7uR/U/fSr/OPanpXc2pachp8PeFL0D1EGYP5u698sZdrsysAPi/wHeBJPAWcK27bzjaMolEwpuamnLUYX76eM5/53Raj6xTyulzNgfTA8Dn/3QOAzq2H1kvHsaAO9/NSQ+HfQFGOrw/74z9x5yFjXrI/x7MbKW7J75cz/WY/jhgs7tvcfe9wGJgco57CM5Q2jKq94YhfmTgp+q56wGgqOPjjOq9YfiqeYf9gQMU216Gr5qnHtRDr/eQ69AvA5rTppNR7TBmNsPMmsysqbW167CQ7rNTyjOq94YdVnqUem5PXd124LSM6r2hL3wBqodwe8h16FsXtSPGl9x9gbsn3D1RWtp1WEgGJs6GwuLDa4XFqXqONJ9XT8eXTl3t8P40n1efsx4AFva/vstTaBf2vz5nPfSFL0D1EG4PuQ79JDA8bboc2JbjHsJTfQ1cMR9OGQ5Y6vGK+Tk9cHn+lX/HO2P/kY8p5YAbH1Oa03HTg2oum8Fsn3HYKbSzfQY1l83IWQ994QtQPYTbQ64P5PYjdSB3ItBC6kDuD9x9/dGW0YFcOdaeWt0S69lUkH62Rhs7bHDMZ4yoh3zs4WgHcnMa+lEjlwIPkjpl83fufs9Xza/QFxHJ3NFCP+fn6bv7c8BzuV6viIjoMgwiIkFR6IuIBEShLyISEIW+iEhAFPoiIgFR6IuIBEShLyISkJz/OCtTZtYKfHgM3mow5PCykn2XtsPfaFukaDuk5Nt2ONPdj7iwT58P/WPFzJq6+nVaaLQd/kbbIkXbISWU7aDhHRGRgCj0RUQCElLoL4i7gT5C2+FvtC1StB1SgtgOwYzpi4hIWHv6IiLBU+iLiAQk70PfzC42s01mttnMGuLuJy5mNtzMXjazjWa23sxui7unOJlZgZmtNrNn4u4lLmY2yMyWmNm70f8X34y7p7iY2R3R38U7ZvYHMyuKu6fektehb2YFwG+AS4BRwLVmNirermKzH/h7dz8XqANmBrwtAG4DNsbdRMx+DTzv7ucAYwh0e5hZGfBj
IOHuVaTu6jct3q56T16HPjAO2OzuW9x9L7AYmBxzT7Fw9+3uvip6vpvUH3hubwzbR5hZOXAZsDDuXuJiZgOBCcBvAdx9r7u3x9tVrPoBxdF9vAcA22Lup9fke+iXAc1p00kCDbp0ZjYCqAXeiLeT2DwI/BQ4EHcjMToLaAUejoa5FprZiXE3FQd3bwHuBz4CtgO73P0/4u2q9+R76FsXtaDPUTWzk4B/A25397/G3U+umdnlwA53Xxl3LzHrB5wH/Iu71wL/BQR5zMvMSkiNAFQAZwAnmtn18XbVe/I99JPA8LTpcvL4n21fx8wKSQX+4+6+NO5+YjIeuNLMPiA13PcdM3ss3pZikQSS7n7wX3tLSH0JhOgiYKu7t7r7PmAp8K2Ye+o1+R76bwEjzazCzPqTOjizLOaeYmFmRmr8dqO7/3Pc/cTF3We5e7m7jyD1/8NL7p63e3VH4+4fA81mdnZUmghsiLGlOH0E1JnZgOjvZCJ5fFC7X9wN9CZ3329m/wt4gdQR+d+5+/qY24rLeOAGYJ2ZrYlqd7n7czH2JPG6FXg82iHaAtwYcz+xcPc3zGwJsIrUWW6ryeNLMugyDCIiAcn34R0REUmj0BcRCYhCX0QkIAp9EZGAKPRFRAKi0BcRCYhCX0QkIP8fJ0xNTNbXddgAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     }
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Testing the model\n",
    "\n",
    "Ideally, MAT should be trained for many more epochs on a GPU. Due to computational constraints, we train the model for only a few epochs in this tutorial. Let us now see how to predict hydration enthalpy values for molecules with MAT."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "source": [
    "# Predict the hydration enthalpy for the SMILES string we featurized earlier in the MATFeaturizer section.\n",
    "model.predict_on_batch(output)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/mat.py:165: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  node_features = torch.tensor(data[0]).float()\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/mat.py:166: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  adjacency_matrix = torch.tensor(data[1]).float()\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/mat.py:167: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  distance_matrix = torch.tensor(data[2]).float()\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/layers.py:171: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.sum(torch.tensor(adj_matrix), dim=-1).unsqueeze(2) + eps)\n",
      "/home/atreyamaj/Desktop/deepchem/deepchem/models/torch_models/layers.py:178: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  distance_matrix = torch.tensor(distance_matrix).squeeze().masked_fill(\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([[-0.32668447]], dtype=float32)"
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.8.10",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.10 64-bit"
  },
  "interpreter": {
   "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
